# Standard Library Imports
import argparse
import os
import sys

# Third-Party Imports
import numpy as np
import torch
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
)
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

# Project-Specific Imports
from dataset import *
from stable_diffusion import CustomStableDiffusionInpaintPipeline
from attn import *
from core import *
from utils import *


def main(args):
    """Evaluate image-level anomaly classification for one object category.

    Loads a CLIP model (``CLIPAD``, ViT-B-16-plus-240 / laion400m_e31) and a
    Stable Diffusion 2 inpainting pipeline with cross-attention hooks. If
    few-shot mode is on, it builds memory banks. It then scores every test
    image of ``args.object`` and computes AUROC / AUPR. Both models are
    reloaded on every call, i.e. once per category.

    The per-image score is ``w * clip_score + (1 - w) * diff_score`` with
    ``w = args.model_weight``:

    - In zero-shot mode (``args.shot`` falsy), each term is the
      language-prompt score only.
    - In few-shot mode, the max of the corresponding vision score map is
      added to each term.

    Args:
        args: Namespace from ``get_args``. The caller must set
            ``args.object`` to the category to evaluate.
            ``args.num_workers`` is optional (defaults to 64 DataLoader
            workers).

    Returns:
        tuple: ``(auroc, aupr)`` over the test split. An image is treated
        as anomalous iff the dataset yields a non-empty mask path for it.
    """
    device_id = args.device_id
    # Number GPUs in PCI bus order. This only has an effect if CUDA has not
    # been initialised yet in this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # The previous version also set CUDA_VISIBLE_DEVICES=device_id and then
    # used f"cuda:{device_id}". When that mask takes effect, the selected GPU
    # is renumbered to cuda:0, so any device_id > 0 was an invalid ordinal.
    # Address the GPU by its index instead, and make it the current device in
    # case a helper allocates on the default CUDA device.
    device = torch.device(f"cuda:{device_id}")
    torch.cuda.set_device(device)

    # Prompt-friendly category name. Falls back to the raw dataset id when
    # object_dictionary has no entry for it.
    object_name = (
        object_dictionary[args.object]
        if args.object in object_dictionary
        else args.object
    )
    num_workers = getattr(args, "num_workers", 64)

    model, _, _ = CLIPAD.create_model_and_transforms(
        model_name='ViT-B-16-plus-240', pretrained='laion400m_e31', precision='fp32'
    )
    model = model.to(device)

    cross_attn_init()

    pipe = CustomStableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
    ).to(device)
    pipe.unet = set_layer_with_name_and_path(pipe.unet)
    pipe.unet = register_cross_attention_hook(pipe.unet)

    # All-zero 512x512 mask handed to the inpainting pipeline. Presumably this
    # means "inpaint nothing" -- confirm against the pipeline implementation.
    mask_image = to_pil_image(torch.zeros((512, 512)))

    # Few-shot memory banks, built only when a shot count is given.
    clip_memory_bank = None
    diff_memory_bank = None
    if args.shot:
        train_dataset = CustomDataset(args.dataset, args.object, args.data_path, args.shot)
        train_dataloader = DataLoader(
            dataset=train_dataset, batch_size=1, num_workers=num_workers, shuffle=False
        )

        clip_memory_bank = build_clip_memory_bank(
            dataloader=train_dataloader,
            model=model,
            device=device,
        )

        diff_memory_bank = build_diff_memory_bank(
            dataloader=train_dataloader,
            pipe=pipe,
            mask_image=mask_image,
            object=object_name,
            template=args.vision_template,
            timesteps=args.vision_timesteps,
            blocks=args.vision_blocks,
        )

    test_dataset = CustomDataset(args.dataset, args.object, args.data_path)
    test_dataloader = DataLoader(
        dataset=test_dataset, batch_size=1, num_workers=num_workers, shuffle=False
    )

    score_list = []
    gt_list = []

    for image_paths, mask_paths in tqdm(test_dataloader, desc=str(args.object)):
        # batch_size=1: unwrap the single collated entry of each field.
        image_path = image_paths[0]
        mask_path = mask_paths[0]

        # The dataset yields an empty mask path for normal images.
        label = 0.0 if mask_path == "" else 1.0

        clip_score = get_clip_language_score(
            image_path=image_path,
            model=model,
            object=object_name,
            template=args.clip_template,
            device=device,
        )

        diff_language_score_map = get_diff_language_score_map(
            pipe=pipe,
            image_path=image_path,
            mask_image=mask_image,
            object=object_name,
            states=args.language_states,
            template=args.language_template,
            timesteps=args.language_timesteps,
            blocks=args.language_blocks,
        )
        # The score is 0 when the map's median equals its max (flat map), and
        # it grows as the peak rises above the median.
        diff_score = 1 - diff_language_score_map.median() / diff_language_score_map.max()

        if args.shot:
            clip_vision_score_map = get_clip_vision_score_map(
                image_path=image_path,
                model=model,
                memory_bank=clip_memory_bank,
                device=device,
            )

            # Use the same object name that build_diff_memory_bank received.
            # Previously the raw args.object was passed here, so the query
            # prompt could differ from the memory-bank prompt whenever
            # object_dictionary remaps the category name.
            diff_vision_score_map = get_diff_vision_score_map(
                pipe=pipe,
                image_path=image_path,
                mask_image=mask_image,
                object=object_name,
                template=args.vision_template,
                timesteps=args.vision_timesteps,
                blocks=args.vision_blocks,
                memory_bank=diff_memory_bank,
            )

            # Non in-place adds, so the helpers' returned values are not mutated.
            clip_score = clip_score + clip_vision_score_map.max()
            diff_score = diff_score + diff_vision_score_map.max()

        total_score = args.model_weight * clip_score + (1 - args.model_weight) * diff_score

        score_list.append(total_score.cpu().numpy())
        gt_list.append(label)

    auroc = roc_auc_score(gt_list, score_list)
    precisions, recalls, _ = precision_recall_curve(gt_list, score_list)
    aupr = auc(recalls, precisions)

    return auroc, aupr


def get_args():
    """Parse the command-line options for the classification evaluation.

    Options that the script uses unconditionally are marked ``required``, so
    a missing value fails at parse time instead of crashing later. These are
    ``--device_id``, ``--task``, ``--dataset`` and ``--model_weight``.
    ``args.object`` is not a CLI option; the ``__main__`` loop sets it for
    each category.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Image-level anomaly classification evaluation."
    )

    # Run / environment settings.
    parser.add_argument('--device_id', type=int, required=True,
                        help='Index of the CUDA device to run on.')
    parser.add_argument('--seed', type=int,
                        help='Random seed passed to set_seed.')
    parser.add_argument('--task', type=str, required=True,
                        help='Run name; logs are written under log/<task>/<dataset>.')
    parser.add_argument('--dataset', type=str, required=True, choices=['mvtec', 'visa'])
    parser.add_argument('--data_path', type=str,
                        help='Dataset location, forwarded to CustomDataset.')
    parser.add_argument('--num_workers', type=int, default=64,
                        help='Worker processes per DataLoader (default: 64).')

    # Few-shot settings and score fusion.
    parser.add_argument('--shot', type=int,
                        help='Number of reference images for the memory banks; '
                             'omit or 0 for language-only (zero-shot) scoring.')
    parser.add_argument('--model_weight', type=float, required=True,
                        help='Weight w in w * clip_score + (1 - w) * diff_score.')

    # CLIP language branch.
    parser.add_argument('--clip_template', type=str,
                        help='Prompt template for the CLIP language score.')

    # Diffusion language branch.
    parser.add_argument('--language_template', type=str)
    parser.add_argument('--language_states', nargs='+', type=str)
    parser.add_argument('--language_timesteps', nargs='+', type=int)
    parser.add_argument('--language_blocks', nargs='+', type=str)

    # Diffusion vision branch (memory bank and query features).
    parser.add_argument('--vision_template', type=str)
    parser.add_argument('--vision_timesteps', nargs='+', type=int)
    parser.add_argument('--vision_blocks', nargs='+', type=str)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    logger = setup_logger("test", base_dir=os.path.join("log", args.task, args.dataset), device_id=args.device_id)
    # Record the exact command line at the start (and again at the end) of the log.
    logger.info("Classification:\n%s", sys.argv)

    set_seed(args.seed)

    auroc_list = []
    aupr_list = []

    # Evaluate every category of the dataset independently. main() reads the
    # current category from args.object. The loop variable is not named
    # `object`: at module scope that would rebind the builtin for the whole
    # module.
    for object_name in get_objects(args.dataset):
        args.object = object_name
        auroc, aupr = main(args)

        logger.info("\n%s", "Object: {}, AUROC={}, AUPR={}".format(object_name, auroc, aupr))

        auroc_list.append(auroc)
        aupr_list.append(aupr)

    # Unweighted mean over categories.
    logger.info("\n%s", "AUROC={}, AUPR={}".format(np.mean(auroc_list), np.mean(aupr_list)))
    logger.info("Classification:\n%s", sys.argv)
